local super = require "GraphLayer"

-- Scatter-plot layer class, created from the GraphLayer module loaded above.
-- Deliberately assigned without `local`: methods in this file refer to it by
-- name (e.g. as a peerPropertyKeyArtifactValues argument), and it is also
-- the module's return value.
ScatterGraphLayer = super:new()

-- Maps each 'style' property string to the stamp function that draws one
-- point in that style; draw() calls it as stamp(canvas, x, y, radius, paint).
local styleStamps = {
    circle = CirclePointStamp,
    hollow = StrokedCirclePointStamp,
    triangle = TrianglePointStamp,
    square = SquarePointStamp,
    diamond = DiamondPointStamp,
    cross = CrossPointStamp,
    star = StarPointStamp,
}

-- Reverse lookup (stamp function -> style string), used by the Style
-- inspector hook to store a chosen stamp back into the 'style' property.
-- NOTE(review): assumes inverttable swaps keys and values; it is defined elsewhere.
local styleStrings = inverttable(styleStamps)

-- Maps the 'regression' property string to the class instantiated by
-- calculateRegression. The keys match the choices offered by
-- getRegressionClasses.
local regressionClasses = {
    linear = LinearRegression,
    quadratic = QuadraticRegression,
    cubic = CubicRegression,
    exponential = ExponentialRegression,
    logarithmic = LogarithmicRegression,
    power = PowerRegression,
    sigmoid = SigmoidRegression,
}

-- Properties registered with an initial value in ScatterGraphLayer:new().
-- 'size' is treated as an area by draw() (radius = sqrt(size) / 2).
local defaults = {
    size = 25,
    style = 'circle',
}

-- Properties registered in ScatterGraphLayer:new() without an initial value.
local nilDefaults = {
    'x', 'y', 'value', 'paint', 'regression',
}

-- Property names returned by getGetterPieceNames. A free layer is positioned
-- by x/y; a position-constrained layer uses '#' and 'value' (presumably '#'
-- denotes the row index — TODO confirm). The common names are appended in
-- both cases.
local freeGetterNames = {'x', 'y'}
local constrainedGetterNames = {'#', 'value'}
local commonGetterNames = {'size', 'paint', 'style'}

-- Inspector descriptors returned by getInspectorInfo, split the same way.
-- NOTE(review): each entry appears to be {inspector type, {accessor spec},
-- title}; this is inferred from the literals here and should be confirmed
-- against the consumer.
local freeInspectorInfo = {
    {'KeyArtifact', {'x'}, 'X'},
    {'KeyArtifact', {'y'}, 'Y'},
}

local constrainedInspectorInfo = {
    {'KeyArtifact', {'value'}, 'Value'},
}

-- 'getSizes:' and 'getSizeIconFunction:' name methods defined on this class.
local commonInspectorInfo = {
    {'Size', {'size', sizes = 'getSizes:', iconFunction = 'getSizeIconFunction:'}, 'Size'},
    {'Color', {'getPaint:setPaint', custom = 'hasExplicitPaint:'}, 'Color'},
}

-- Globals cached as locals; they are used inside the per-point callbacks of
-- draw() and calculateRegression().
local tonumber = tonumber
local _sqrt = math.sqrt

--- Create a scatter layer.
-- Registers every property from `defaults` (with its initial value) and from
-- `nilDefaults` (unset). Also creates the closure that getRegression hooks
-- up to clear the cached regression.
-- @return the new layer
function ScatterGraphLayer:new()
    local layer = super.new(self)

    for name, initial in pairs(defaults) do
        layer:addProperty(name, initial)
    end
    for _, name in ipairs(nilDefaults) do
        layer:addProperty(name)
    end

    -- Cached regression (filled lazily by getRegression) and its invalidator.
    layer._regression = nil
    layer._regressionInvalidator = function()
        layer._regression = nil
    end

    return layer
end

-- For a position-constrained layer with no 'value' yet: pick a numeric field,
-- passing the fields used by peer scatter layers to pickField (presumably as
-- fields to avoid).
local function chooseDefaultValueField(layer, dataset)
    if layer:getProperty('value') ~= nil then
        return
    end
    local taken = layer:peerPropertyKeyArtifactValues(ScatterGraphLayer, 'value')
    local field = dataset:pickField('number', taken)
    if field then
        layer:setProperty('value', KeyArtifact:new(field))
    end
end

-- For a free layer with neither x nor y set, choose both fields:
--   x: the last peer layer's x field, or else a field matching the horizontal
--      axis' preferred type;
--   y: a field for the vertical axis' preferred type, passing the peers' y
--      fields plus the chosen x field to pickField.
-- Both are set only when both are found.
local function chooseDefaultXYFields(layer, dataset)
    if layer:getProperty('x') ~= nil or layer:getProperty('y') ~= nil then
        return
    end
    local peerXs = layer:peerPropertyKeyArtifactValues(GraphLayer, 'x')
    local xField = peerXs[#peerXs]
        or dataset:pickField(layer:getParent():getHorizontalAxis():getPreferredType())
    local avoidY = layer:peerPropertyKeyArtifactValues(ScatterGraphLayer, 'y')
    avoidY[#avoidY + 1] = xField
    local yField = dataset:pickField(layer:getParent():getVerticalAxis():getPreferredType(), avoidY)
    if xField and yField then
        layer:setProperty('x', KeyArtifact:new(xField))
        layer:setProperty('y', KeyArtifact:new(yField))
    end
end

--- After unarchiving, bind default data fields when the layer has a dataset
-- but no fields yet, then defer to GraphLayer's unarchived().
function ScatterGraphLayer:unarchived()
    local dataset = self:getDataset()
    if dataset then
        if self:isPositionConstrained() then
            chooseDefaultValueField(self, dataset)
        else
            chooseDefaultXYFields(self, dataset)
        end
    end
    super.unarchived(self)
end

--- Restore the 'style' property from its archived representation.
-- @param archived archived value, decoded with unarchive()
function ScatterGraphLayer:unarchiveStyle(archived)
    self:setProperty('style', unarchive(archived))
end

--- List the property names exposed for this layer.
-- @param constrained truthy for a position-constrained layer
-- @return a new table: the positional names for the layer kind, followed by
--   the common names
function ScatterGraphLayer:getGetterPieceNames(constrained)
    local positional = freeGetterNames
    if constrained then
        positional = constrainedGetterNames
    end
    local names = {}
    appendtables(names, positional)
    appendtables(names, commonGetterNames)
    return names
end

--- List the inspector descriptors for this layer.
-- @param constrained truthy for a position-constrained layer
-- @return a new table: the positional descriptors for the layer kind,
--   followed by the common (Size, Color) descriptors
function ScatterGraphLayer:getInspectorInfo(constrained)
    local positional = freeInspectorInfo
    if constrained then
        positional = constrainedInspectorInfo
    end
    local info = {}
    appendtables(info, positional)
    appendtables(info, commonInspectorInfo)
    return info
end

--- Return a fresh list of the point-stamp functions used as the Style
-- inspector's choices, in a fixed order.
function ScatterGraphLayer:getStyleStamps()
    local stamps = {}
    stamps[1] = CirclePointStamp
    stamps[2] = StrokedCirclePointStamp
    stamps[3] = TrianglePointStamp
    stamps[4] = SquarePointStamp
    stamps[5] = DiamondPointStamp
    stamps[6] = CrossPointStamp
    stamps[7] = StarPointStamp
    return stamps
end

--- Return the trend-line choices as {key, label} pairs.
-- The nil-keyed entry means "no trend line". The empty entry is presumably
-- rendered as a separator. Logarithmic and power fits are appended only when
-- the horizontal axis requires number values.
-- @param parent the graph whose horizontal axis is checked
function ScatterGraphLayer:getRegressionClasses(parent)
    local choices = {
        {nil, 'None'},
        {},
        {'linear', 'Linear'},
        {'quadratic', 'Quadratic'},
        {'cubic', 'Cubic'},
        {'exponential', 'Exponential'},
        {'sigmoid', 'Sigmoid'},
    }
    if not parent:getHorizontalAxis():requiresNumberValues() then
        return choices
    end
    local numericOnly = {
        {'logarithmic', 'Logarithmic'},
        {'power', 'Power'},
    }
    appendtables(choices, numericOnly)
    return choices
end

-- Build the 'Style' inspector. Its hook translates between the layer's style
-- string and the matching stamp function, and it observes the 'style'
-- property hook.
-- NOTE(review): a new observer is added on every call; confirm the
-- inspector lifecycle removes it.
local function newStyleInspector(layer)
    local styleInspector = Inspector:new{
        title = 'Style',
        type = 'Stamp',
        constraint = function()
            return ScatterGraphLayer:getStyleStamps()
        end,
    }
    local styleHook = Hook:new(
        function()
            return styleStamps[layer:getProperty('style')]
        end,
        function(stamp)
            layer:setProperty('style', styleStrings[stamp])
        end)
    layer:getPropertyHook('style'):addObserver(styleHook)
    styleInspector:addHook(styleHook)
    return styleInspector
end

-- Build the 'Trend Line' group inspector. Its children (the regression-class
-- chooser and the equation/R² toggles) are created lazily by `target`.
local function newTrendLineInspector(layer, parent)
    return Inspector:new{
        title = 'Trend Line',
        type = 'List.Group',
        target = function()
            local children = List:new()

            local classInspector = Inspector:new{
                title = 'Trend Line',
                type = 'Class',
                constraint = function()
                    return ScatterGraphLayer:getRegressionClasses(parent)
                end,
            }
            classInspector:addHook(layer:getPropertyHook('regression'))
            children:add(classInspector)

            local equationInspector = Inspector:new{
                title = 'Show Equation',
                type = 'trendline',
            }
            -- Equations are only offered for a number-valued horizontal axis.
            if parent:getHorizontalAxis():requiresNumberValues() then
                equationInspector:addHook(parent:getShowRegressionEquationsHook(), 'equations')
            end
            equationInspector:addHook(parent:getShowRegressionR2sHook(), 'r2s')
            children:add(equationInspector)

            return children
        end,
    }
end

--- Return the inspector list: GraphLayer's inspectors plus Style, and a
-- Trend Line group when the horizontal axis is quantitative and the vertical
-- axis requires number values.
function ScatterGraphLayer:getInspectors()
    local inspectors = super.getInspectors(self)
    inspectors:add(newStyleInspector(self))

    local parent = self:getParent()
    local supportsTrendLine = parent:getHorizontalAxis():isQuantitative()
        and parent:getVerticalAxis():requiresNumberValues()
    if supportsTrendLine then
        inspectors:add(newTrendLineInspector(self, parent))
    end
    return inspectors
end

--- Return a fresh list of the point sizes offered by the Size inspector
-- (referenced there as 'getSizes:').
function ScatterGraphLayer:getSizes()
    local sizes = {2, 4, 6, 10, 16, 25, 40, 64, 100, 160, 250}
    return sizes
end

--- Return the painter used for Size inspector icons.
-- The painter fills the canvas' metrics rect with a black oval when `size`
-- is at least the rect height. Otherwise it draws a faint gray oval over the
-- whole rect and a black oval inset so that its height equals `size`.
-- NOTE(review): draw() treats size as an area (radius = sqrt(size) / 2),
-- but here it is used directly as a diameter, so all sizes at or above the
-- icon height render identically. Confirm whether this is intended.
function ScatterGraphLayer:getSizeIconFunction()
    local function paintSizeIcon(canvas, size)
        local bounds = Rect:new(canvas:metrics():rect())
        local height = bounds:height()
        if size < height then
            local inset = (height - size) / 2
            canvas:setPaint(Color.gray(0, 0.125))
                :fill(Path.oval(bounds))
                :setPaint(Color.black)
                :fill(Path.oval(bounds:insetXY(inset, inset)))
            return
        end
        canvas:setPaint(Color.black)
            :fill(Path.oval(bounds))
    end
    return paintSizeIcon
end

--- Call `mapFunction` on every value plotted along `orientation`.
-- A position-constrained layer always uses 'value'. A free layer uses 'x'
-- for Graph.horizontalOrientation and 'y' for any other orientation.
-- @param orientation a Graph orientation constant
-- @param mapFunction called once per value
function ScatterGraphLayer:iterateValues(orientation, mapFunction)
    local dataset = self:getDataset()
    local name
    if self:isPositionConstrained() then
        name = 'value'
    elseif orientation == Graph.horizontalOrientation then
        name = 'x'
    else
        name = 'y'
    end
    for _, v in self:getPropertySequence(name, dataset):iter() do
        mapFunction(v)
    end
end

--- Select the trend-line type by key (a regressionClasses key such as
-- 'linear'), or nil for no trend line.
function ScatterGraphLayer:setRegression(kind)
    self:setProperty('regression', kind)
end

--- Return this layer's regression, computing and caching it on first use.
-- The cache is cleared by `_regressionInvalidator` when the dataset or the
-- x, y, value or regression properties change.
-- NOTE(review): a change of the horizontal axis type (which changes
-- `normalizePosition`) does not clear the cache.
-- @param propertySequence per-point sequence passed to calculateRegression
-- @param normalizePosition maps a position to a number for fitting
-- @return a finished regression, false if fitting failed, or nil when no
--   trend line is selected
function ScatterGraphLayer:getRegression(propertySequence, normalizePosition)
    if self._regression == nil then
        -- Register the invalidator only once. The cache stays nil on every
        -- draw when no trend line is selected. Re-registering each time
        -- would pile up duplicate observers unless addObserver dedupes.
        if not self._regressionObserving then
            self._regressionObserving = true
            self._datasetHook:addObserver(self._regressionInvalidator)
            -- 'value' feeds the sequence of position-constrained layers,
            -- so it must invalidate too.
            for _, name in ipairs{'x', 'y', 'value', 'regression'} do
                self:getPropertyHook(name):addObserver(self._regressionInvalidator)
            end
        end
        self._regression = self:calculateRegression(propertySequence, normalizePosition)
    end
    return self._regression
end

--- Fit the selected regression to the layer's points.
-- @param propertySequence sequence whose `each` callback receives
--   (position, value, ...)
-- @param normalizePosition maps a position to a number; may be nil when the
--   horizontal axis has no numeric form
-- @return a finished regression; false if regression:finish() reports
--   failure; nil when no trend line is selected or positions cannot be
--   normalized
function ScatterGraphLayer:calculateRegression(propertySequence, normalizePosition)
    local regressionClass = regressionClasses[self:getProperty('regression')]
    -- Without a normalizer, update() below would call nil. Treat it as
    -- "no regression" instead of raising.
    if not regressionClass or not normalizePosition then
        return nil
    end

    local regression = regressionClass:new()
    regression:init()
    propertySequence:each(function(position, value)
        regression:update(normalizePosition(position), tonumber(value))
    end)
    if not regression:finish() then
        return false
    end
    return regression
end

--- Draw the layer's points and, when a trend line is selected and can be
-- fitted, the regression curve plus an equation / R² label placed in the
-- graph's 'regression' overlay.
-- @param canvas canvas to draw on (its setters return the canvas, so calls chain)
-- @param rect plot rectangle; stamps whose centre lies outside it are skipped
-- @param propertySequence sequence whose `each` callback receives
--   (position, value, size, paint, style) for each point
-- @param xScaler maps a horizontal-axis value to a canvas x (may yield nil)
-- @param yScaler maps a vertical-axis value to a canvas y (may yield nil)
function ScatterGraphLayer:draw(canvas, rect, propertySequence, xScaler, yScaler)
    local parent = self:getParent()
    local baseSize = parent:getBaseSize()
    local defaultPaint = self:getPaint()
    local isVertical = self:getOrientation() == Graph.verticalOrientation
    local regressionPaint
    canvas:clipRect(rect:expand{left = 4, bottom = 4, right = 4, top = 4})

    -- Choose how positions become numbers for fitting. Equations are only
    -- shown for a number-valued horizontal axis; date positions use
    -- date:numeric().
    local regression
    local normalizePosition
    local canShowRegressionEquations
    if parent:getVerticalAxis():requiresNumberValues() then
        local horizontalAxis = parent:getHorizontalAxis()
        if horizontalAxis:requiresNumberValues() then
            normalizePosition = tonumber
            canShowRegressionEquations = true
        elseif horizontalAxis:requiresDateValues() then
            normalizePosition = function(date)
                if type(date) == 'userdata' then
                    return date:numeric()
                end
            end
        end
        -- Any other horizontal axis leaves no normalizer. Fitting or
        -- evaluating the curve would then call nil, so skip the trend line.
        -- (Presumably reachable when 'regression' stays set after the axis
        -- type changes.)
        if normalizePosition then
            regression = self:getRegression(propertySequence, normalizePosition)
        end
    end

    propertySequence:each(function(position, value, area, paint, styleString)
        local x, y
        if isVertical then
            x, y = xScaler(position), yScaler(value)
        else
            y, x = yScaler(position), xScaler(value)
        end
        local styleStamp = styleStamps[styleString]
        -- Size is an area: radius = sqrt(area) / 2, defaulting to an area of 10.
        local radius = _sqrt(tonumber(area) or 10) / 2
        if styleStamp and x and y and rect.left <= x and x <= rect.right and rect.bottom <= y and y <= rect.top then
            styleStamp(canvas, x, y, radius * baseSize, paint or defaultPaint)
        end
        -- The trend line takes the first explicit per-point paint, if any.
        regressionPaint = regressionPaint or paint
    end)

    if regression then
        regressionPaint = regressionPaint or defaultPaint
        canvas:clipRect(rect)
            :setPaint(regressionPaint)
            :setThickness(baseSize)

        -- Sample the curve once per horizontal canvas unit. A point the
        -- scalers cannot map ends the current segment.
        local axis = parent:getHorizontalAxis()
        local path
        for xPt = rect:minx(), rect:maxx() do
            local x = axis:scaled(rect, xPt)
            local y = regression:getValue(normalizePosition(x))
            if y then
                x, y = xScaler(x), yScaler(y)
                if x and y then
                    if path then
                        path:addLine{x = x, y = y}
                    else
                        path = Path.point{x = x, y = y}
                    end
                elseif path then
                    canvas:stroke(path)
                    path = nil
                end
            end
        end
        if path then
            canvas:stroke(path)
        end

        local showEquations = canShowRegressionEquations and parent:showRegressionEquations()
        local showR2s = parent:showRegressionR2s()
        if showEquations or showR2s then
            local pieces = {}
            if showEquations then
                pieces[#pieces + 1] = regression:getEquation()
            end
            if showR2s then
                local R2 = regression:getR2()
                -- Only label R² values that are numbers within [0, 1].
                if R2 and R2 >= 0 and R2 <= 1 then
                    pieces[#pieces + 1] = 'R^{2} = ' .. string.format('%0.4f', R2)
                end
            end
            -- Single-value assignment drops gsub's match count.
            local labelText = string.gsub(table.concat(pieces, '     '), '-', string.minus)
            local styledString = superscript(labelText, parent:getLayerLabelFont())
            local stringRect = styledString:measure()

            -- Stack this layer's label on top of any labels already placed
            -- in the shared 'regression' overlay by other layers.
            local previousStamp, previousWidth, previousHeight = parent:getOverlay('regression')
            local margin, padx, pady = 4, 8, 4
            local stamp = function(overlayCanvas, box, isRecursed)
                -- Only the outermost call paints the shared background panel.
                if not isRecursed then
                    local panel = Path.rect(box:insetXY(margin, margin), pady)
                    overlayCanvas:preserve(function(preserved)
                        preserved:setOpacity(0.7)
                            :setPaint(parent:getBackgroundPaint())
                            :fill(panel)
                            :setOpacity(0.5)
                            :setPaint(parent:getAxisPaint())
                            :setThickness(1 * baseSize)
                            :stroke(panel)
                    end)
                end
                if previousStamp then
                    local previousRect = box:copy()
                    previousRect.bottom = previousRect.top - previousHeight
                    previousStamp(overlayCanvas, previousRect, true)
                end
                local truncatedStyledString = styledString:truncate(box:width() - margin * 2 - padx * 1.5)
                local textX = box.left + margin + padx
                local textY = box.bottom + margin + pady - stringRect:miny()
                -- Background-coloured outline behind the text for contrast.
                overlayCanvas:setPaint(parent:getBackgroundPaint())
                    :setThickness(1 * baseSize)
                    :strokeText(truncatedStyledString, textX, textY, 0)
                overlayCanvas:setPaint(regressionPaint)
                    :drawText(truncatedStyledString, textX, textY, 0)
            end
            local width = math.max(previousWidth or 0, stringRect:width() + margin * 2 + padx * 2)
            local height = (previousHeight or margin * 2 + pady * 2) + stringRect:height() * 1.25
            parent:setOverlay('regression', stamp, width, height, 0, 1)
        end
    end
end

-- Module result: the class table (also assigned to the global ScatterGraphLayer above).
return ScatterGraphLayer
